-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add changes for e2e inference for passing test cases #1127
Conversation
18bbb27
to
c57a0a3
Compare
|
|
# Post processing | ||
probabilities = torch.nn.functional.softmax(output[0][0], dim=0) | ||
|
||
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" | ||
urllib.request.urlretrieve(url, "imagenet_classes.txt") | ||
|
||
with open("imagenet_classes.txt", "r") as f: | ||
categories = [s.strip() for s in f.readlines()] | ||
top5_prob, top5_catid = torch.topk(probabilities, 5) | ||
for i in range(top5_prob.size(0)): | ||
print(categories[top5_catid[i]], top5_prob[i].item()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've noticed that this pattern arises in other models as well. Can we put this post processing part inside utils
folder e.g. ghostnet/utils/utils.py
where we would have method something like this:
def post_processing(top_k, url, file_name):
# Post processing
probabilities = torch.nn.functional.softmax(output[0][0], dim=0)
urllib.request.urlretrieve(url, file_name)
with open(file_name, "r") as f:
categories = [s.strip() for s in f.readlines()]
topk_prob, topk_catid = torch.topk(probabilities, top_k)
for i in range(topk_prob.size(0)):
print(categories[topkcatid[i]], topk_prob[i].item())
Also let's do this for other models as well :)
# Create model | ||
model = MobileNetV1(9) | ||
model.eval() | ||
|
||
# Load data sample | ||
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg") | ||
urllib.request.urlretrieve(url, filename) | ||
|
||
# Preprocessing | ||
input_image = Image.open(filename) | ||
preprocess = transforms.Compose( | ||
[ | ||
transforms.Resize(256), | ||
transforms.CenterCrop(224), | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | ||
] | ||
) | ||
input_tensor = preprocess(input_image) | ||
input_batch = input_tensor.unsqueeze(0) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Initialization of all mobilenet_vx
models is almost the same, can we also put that function in .../mobilenet/utils/
folder so we can reuse it for other mobilenets as well.
bd924fc
to
13e9f5b
Compare
|
|
def post_processing(output, url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"): | ||
|
||
probabilities = torch.nn.functional.softmax(output[0][0], dim=0) | ||
urllib.request.urlretrieve(url, "imagenet_classes.txt") | ||
|
||
with open("imagenet_classes.txt", "r") as f: | ||
categories = [s.strip() for s in f.readlines()] | ||
top5_prob, top5_catid = torch.topk(probabilities, 5) | ||
for i in range(top5_prob.size(0)): | ||
print(categories[top5_catid[i]], top5_prob[i].item()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please pass the number of top k categories as a param to the function (where default is 5) because we might not always want the top 5 categories. Apply this change in other utils as well :)
a81ea0c
to
a470a9f
Compare
top5_prob, top5_catid = torch.topk(probabilities, top_k) | ||
for i in range(top5_prob.size(0)): | ||
print(categories[top5_catid[i]], top5_prob[i].item()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please rename the top5_prob
and top5_catid
to topk_prob
and topk_catid
in this and other files as well.
|
|
8361b90
to
87f7c99
Compare
|
|
|
||
|
||
@pytest.mark.nightly | ||
@pytest.mark.parametrize("variant", variants, ids=variants) | ||
@pytest.mark.parametrize("variant", params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pls rename back to variants
to maintain the standard format
|
||
|
||
@pytest.mark.nightly | ||
@pytest.mark.parametrize("variant", variants, ids=variants) | ||
@pytest.mark.parametrize("variant", params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here too
|
||
|
||
variants = ["wide_resnet50_2", "wide_resnet101_2"] | ||
params = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
categories = [s.strip() for s in f.readlines()] | ||
topk_prob, topk_catid = torch.topk(probabilities, top_k) | ||
for i in range(topk_prob.size(0)): | ||
print(categories[topk_catid[i]], topk_prob[i].item()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please cleanup downloaded imagenet_classes.txt file after successful post processing
|
||
# STEP 3: Prepare input | ||
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg") | ||
urllib.request.urlretrieve(url, filename) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please cleanup downloaded image file after successful processing. Also pls update for other models too
4cdb2b0
to
274a0f8
Compare
|
|
eb45ecf
to
0d602c6
Compare
0d602c6
to
eb4f134
Compare
|
1 similar comment
|
|
1 similar comment
|
### Summary - End to end inference changes added for passing test cases present in Mobilenet v1, Mobilenet v2, Mobilenet v3, Resnext, wideresnet , ghostnet & DLA Models - Push marker added for vilt & all above models
Summary